#!/usr/bin/env python3

from garage import wrap_experiment
from garage.envs import GymEnv, normalize
from garage.experiment.deterministic import set_seed
from garage.sampler import RaySampler
from garage.torch.algos import SHARP
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.policies import GaussianMLPPolicy
from garage.torch.value_functions import GaussianMLPValueFunction
from garage.trainer import Trainer
from garage.torch.optimizers.SHARP_optimizer import SHARPOptimizer


import torch

# Values handed to SHARPOptimizer under the keys "a" and "b" in run_task, and
# also embedded in the experiment's log directory name.
# NOTE(review): their exact role inside SHARPOptimizer is not visible in this
# file -- confirm against the optimizer implementation before tuning.
a, b = 0.05, 5

def run_task(seed):
    """Run one SHARP training experiment on HalfCheetah-v2 for ``seed``.

    The experiment function is wrapped with ``wrap_experiment`` and given a
    log directory whose name encodes the module-level ``a`` and ``b`` values
    together with the seed. The wrapped function is then invoked once.

    Args:
        seed (int): Seed forwarded to the wrapped experiment; it is also
            embedded in the log directory name.
    """
    # Built up front so the decorator line stays readable.
    log_dir = "/root/Data/jmlr/halfCheetah-SHARP-a={}-b={}-seed={}".format(
        a, b, seed)

    @wrap_experiment(log_dir=log_dir, archive_launch_repo=False)
    def sharp_halfCheetah(ctxt=None, seed=43):
        """Build the SHARP components and train them on HalfCheetah-v2.

        Args:
            ctxt (garage.experiment.ExperimentContext): Experiment context
                supplied by ``wrap_experiment``; passed on to ``Trainer``.
            seed (int): Passed to ``set_seed`` and used to seed the wrapped
                environment and its action space.
        """
        total_epochs = 1000
        train_batch_size = 10000

        # Seed first: everything constructed below is created after this.
        set_seed(seed)
        trainer = Trainer(ctxt)

        # Environment, with the wrapped env and the action space seeded too.
        env = GymEnv('HalfCheetah-v2')
        env._env.seed(seed)
        env.action_space.seed(seed)

        # Policy is built before the value function (original order kept).
        policy = GaussianMLPPolicy(env.spec, hidden_sizes=[64, 64])
        value_function = GaussianMLPValueFunction(
            env_spec=env.spec,
            hidden_sizes=[32, 32],
            hidden_nonlinearity=torch.tanh,
            output_nonlinearity=None,
        )

        sampler = RaySampler(agents=policy, envs=env, max_episode_length=500)

        # Optimizer class paired with its config dict, wrapped for the policy.
        optimizer_spec = (SHARPOptimizer, {"a": a, "b": b})
        policy_optimizer = OptimizerWrapper(optimizer_spec, policy)

        algo = SHARP(
            env_spec=env.spec,
            policy=policy,
            value_function=value_function,
            sampler=sampler,
            discount=0.99,
            center_adv=False,
            policy_optimizer=policy_optimizer,
            neural_baseline=True,
        )

        trainer.setup(algo, env)
        trainer.train(n_epochs=total_epochs, batch_size=train_batch_size)

    sharp_halfCheetah(seed=seed)


# Alternative seed list, currently disabled:
# seeds = [14, 33, 3, 4, 49, ]
seeds = [
    7, 8, 21, 28, 35,
    41, 16, 10, 27, 1,
]

# Launch the experiments one after another, one run_task call per seed.
for seed in seeds:
    run_task(seed=seed)